/*
 * @Author: Wangjun
 * @Date: 2023-03-21 09:11:19
 * @LastEditTime: 2024-06-05 09:31:33
 * @LastEditors: wangjun haodreams@163.com
 * @Description:
 * @FilePath: \aexp\node.go
 * hnxr
 */
package aexp

import (
	"bytes"
	"errors"
	"go/ast"
	"go/token"
	"strconv"
)

// Get is a lookup callback that resolves a string key to a value, or reports
// an error. NOTE(review): not referenced in this file; presumably the key is a
// variable name (as used by Ident lookups) — confirm against callers.
type Get func(key string) (any, error)

// ParseNumber converts a numeric literal node (token.INT or token.FLOAT) into
// a float64.
//
// Integer literals are first parsed as plain decimal, which keeps the
// long-standing result for inputs such as "017" (read as 17). If that fails,
// the literal is parsed with Go integer-literal syntax (strconv.ParseInt with
// base 0, 64-bit range), so hexadecimal ("0x1F"), binary ("0b101"), explicit
// octal ("0o17") and digit-separated ("1_000") forms are accepted too.
// Float literals are parsed with strconv.ParseFloat.
//
// A nil node or any other literal kind yields an error; strconv errors are
// returned unchanged.
func ParseNumber(b *ast.BasicLit) (n float64, err error) {
	if b != nil {
		switch b.Kind {
		case token.INT:
			// Plain decimal first so existing results stay unchanged.
			if i, convErr := strconv.Atoi(b.Value); convErr == nil {
				return float64(i), nil
			}
			// go/ast hands us Go-syntax literals; Atoi rejects prefixes and
			// underscores (and, on 32-bit platforms, values beyond int32).
			i, convErr := strconv.ParseInt(b.Value, 0, 64)
			if convErr != nil {
				return 0, convErr
			}
			return float64(i), nil
		case token.FLOAT:
			f, convErr := strconv.ParseFloat(b.Value, 64)
			if convErr != nil {
				return 0, convErr
			}
			return f, nil
		}
	}
	return 0, errors.New("不是有效的数字")
}

// unaryNot implements logical negation ("!"): a value equal to 0 becomes 1 and
// any other value becomes 0. Results are always float64 so they compose with
// the rest of the evaluator, which type-asserts float64 values.
type unaryNot struct {
}

// String returns the operator symbol.
func (m *unaryNot) String() string {
	return "!"
}

// Eval negates a float64 scalar, or each element of a []float64 into a new
// slice. Any other operand type is an error.
func (m *unaryNot) Eval(x any) (any, error) {
	switch x1 := x.(type) {
	case float64:
		if x1 == 0 {
			return float64(1), nil
		}
		// Must be float64: an untyped 0 would be boxed as int and fail the
		// float64 assertions made by consumers such as BinaryExpr.
		return float64(0), nil
	case []float64:
		v := make([]float64, len(x1))
		for i, e := range x1 {
			if e == 0 {
				v[i] = 1
			}
		}
		return v, nil
	}
	return nil, errors.New("'!' 参数错误")
}

// unarySub implements arithmetic negation ("-") for float64 scalars and
// []float64 vectors.
type unarySub struct{}

// String returns the operator symbol.
func (m *unarySub) String() string { return "-" }

// Eval returns -x for a float64 operand, or a freshly allocated slice holding
// the negation of every element for a []float64 operand. Any other operand
// type is an error.
func (m *unarySub) Eval(x any) (any, error) {
	if scalar, ok := x.(float64); ok {
		return -scalar, nil
	}
	vec, ok := x.([]float64)
	if !ok {
		return nil, errors.New("'-' 参数错误")
	}
	out := make([]float64, len(vec))
	for i, val := range vec {
		out[i] = -val
	}
	return out, nil
}

// UnaryOPer is a prefix operator applied to the already-evaluated value of its
// operand; this file implements it with unaryNot ("!") and unarySub ("-").
type UnaryOPer interface {
	String() string
	Eval(x any) (any, error)
}

// UnaryExpr applies op to the value of the operand expression x.
type UnaryExpr struct {
	x  Evaler
	op UnaryOPer
}

// String renders the expression as the operator followed by its operand's
// rendering, e.g. "-x". A non-empty result keeps unary operands visible when
// an enclosing node such as CallExpr renders its arguments.
func (m *UnaryExpr) String() string {
	return m.op.String() + m.x.String()
}

// Eval evaluates the operand and applies the operator to the result. An
// operand error is returned unchanged.
func (m *UnaryExpr) Eval() (any, error) {
	x, err := m.x.Eval()
	if err != nil {
		return nil, err
	}
	return m.op.Eval(x)
}

// ParenExpr is a parenthesized sub-expression, "(x)". The node adds no
// behavior of its own: it evaluates and renders exactly as its inner
// expression.
type ParenExpr struct {
	x Evaler
}

// String returns the inner expression's rendering (without parentheses).
func (p *ParenExpr) String() string {
	return p.x.String()
}

// Eval delegates to the inner expression.
func (p *ParenExpr) Eval() (any, error) {
	inner := p.x
	return inner.Eval()
}

// BinaryExpr is a binary operation "x op y".
type BinaryExpr struct {
	x, y Evaler
	op   OPer
}

// String returns the operator's textual form (not the full expression).
func (m *BinaryExpr) String() string {
	return m.op.String()
}

// Eval evaluates x, then y, and applies op to both values. Operand errors are
// returned unchanged.
//
// Logical and (*land) and logical or (*lor) short-circuit when the left value
// is a float64 that already decides the result: "0 && y" yields 0 and
// "non-zero || y" yields 1, without evaluating y. Truthiness is "value != 0",
// the same rule the '!' operator (unaryNot) uses.
func (m *BinaryExpr) Eval() (any, error) {
	x, err := m.x.Eval()
	if err != nil {
		return nil, err
	}
	// Compare the float directly rather than via int(v): the int conversion
	// truncates fractions (0.5 would look false) and is implementation-defined
	// for NaN and out-of-range values, so results varied by platform.
	if v, ok := x.(float64); ok {
		switch m.op.(type) {
		case *land:
			if v == 0 {
				return float64(0), nil
			}
		case *lor:
			if v != 0 {
				return float64(1), nil
			}
		}
	}
	y, err := m.y.Eval()
	if err != nil {
		return nil, err
	}
	return m.op.Eval(x, y)
}

// Ident is a named variable resolved through its Context at evaluation time.
type Ident struct {
	ctx  *Context
	name string
}

// Eval looks the name up via ctx.get.
//
// If the lookup fails, the value and error it produced are returned as-is.
// A string, or a non-nil *string (dereferenced), is returned as a string.
// Every other value, including a nil *string, is converted with
// ToNumber[float64]; unsupported types therefore yield ErrNan.
func (m *Ident) Eval() (any, error) {
	v, err := m.ctx.get(m.name)
	if err != nil {
		return v, err
	}
	switch s := v.(type) {
	case string:
		return s, nil
	case *string:
		// Dereferencing a nil *string would panic; let it fall through to
		// ToNumber, which reports it as ErrNan.
		if s != nil {
			return *s, nil
		}
	}
	return ToNumber[float64](v)
}

// String returns the identifier's name.
func (m *Ident) String() string {
	return m.name
}

// CallExpr is a function call "name(arg1,arg2,...)" dispatched to method.
//
// fvals is a scratch slice, indexed in parallel with args, that receives the
// evaluated argument values on every Eval and is passed straight to method.
// NOTE(review): this assumes fvals was sized to at least len(args) by whoever
// builds the node, and that no Function retains its argument slice (it is
// overwritten on the next Eval) — confirm.
type CallExpr struct {
	args   []Evaler
	fvals  []any
	name   string
	method Function
}

// String renders the call as name(arg,arg,...) using each argument's String.
func (m *CallExpr) String() string {
	var out bytes.Buffer
	out.WriteString(m.name)
	out.WriteString("(")
	for i := range m.args {
		if i > 0 {
			out.WriteString(",")
		}
		out.WriteString(m.args[i].String())
	}
	out.WriteString(")")
	return out.String()
}

// Eval evaluates the arguments left to right into fvals and then invokes
// method with them. The first argument error aborts the call and is returned
// unchanged.
func (m *CallExpr) Eval() (any, error) {
	for i := range m.args {
		val, err := m.args[i].Eval()
		m.fvals[i] = val
		if err != nil {
			return nil, err
		}
	}
	return m.method(m.fvals...)
}

// ErrNan is returned by ToNumber when its argument is not a supported numeric
// or boolean value (this includes nil pointers).
var ErrNan = errors.New("not a number")

// ToNumber converts a numeric value to T, which is either int64 or float64.
//
// Accepted inputs are bool (true → 1, false → 0), every signed and unsigned
// integer type except uintptr, float32, float64, and non-nil pointers to any
// of those. Any other input, including a nil pointer, yields (0, ErrNan).
//
// Values are converted with Go's T(v) rules: a uint/uint64 above
// math.MaxInt64 wraps when T is int64, large integers may lose precision when
// T is float64, and a float converted to int64 is truncated toward zero (the
// result is implementation-defined for NaN or out-of-range values).
func ToNumber[T int64 | float64](a any) (T, error) {
	switch v := a.(type) {
	case bool:
		if v {
			return 1, nil
		}
		return 0, nil
	case int:
		return T(v), nil
	case int8:
		return T(v), nil
	case int16:
		return T(v), nil
	case int32:
		return T(v), nil
	case int64:
		return T(v), nil
	case uint:
		return T(v), nil
	case byte:
		return T(v), nil
	case uint16:
		return T(v), nil
	case uint32:
		return T(v), nil
	case uint64:
		return T(v), nil
	case float32:
		return T(v), nil
	case float64:
		return T(v), nil

	// Pointer cases: a nil pointer falls out of the switch to ErrNan instead
	// of panicking on the dereference.
	case *bool:
		if v != nil {
			return ToNumber[T](*v)
		}
	case *int:
		if v != nil {
			return T(*v), nil
		}
	case *int8:
		if v != nil {
			return T(*v), nil
		}
	case *int16:
		if v != nil {
			return T(*v), nil
		}
	case *int32:
		if v != nil {
			return T(*v), nil
		}
	case *int64:
		if v != nil {
			return T(*v), nil
		}
	case *uint:
		if v != nil {
			return T(*v), nil
		}
	case *byte:
		if v != nil {
			return T(*v), nil
		}
	case *uint16:
		if v != nil {
			return T(*v), nil
		}
	case *uint32:
		if v != nil {
			return T(*v), nil
		}
	case *uint64:
		if v != nil {
			return T(*v), nil
		}
	case *float32:
		if v != nil {
			return T(*v), nil
		}
	case *float64:
		if v != nil {
			return T(*v), nil
		}
	}
	return 0, ErrNan
}
